/*
 * Copyright 2021-2022 Open Kunlun Technology <https://www.openkunlun.io>
 */

package io.openkunlun.scaladsl.server

import akka.{ Done, NotUsed }
import akka.actor.{ ActorSystem, Cancellable, CoordinatedShutdown }
import akka.event.{ Logging, LoggingAdapter }
import akka.grpc.GrpcClientSettings
import akka.http.scaladsl.{ ConnectionContext, Http, HttpsConnectionContext }
import akka.pki.pem.{ DERPrivateKeyLoader, PEMDecoder }
import akka.stream.Materializer
import akka.stream.scaladsl.Source
import io.openkunlun.scaladsl.v1.{ AbolishRequest, DaprAppPowerApiHandler, DaprClient, EstablishRequest, KeepAliveRequest, KeepAliveResponse }

import java.io.InputStream
import java.security.cert.{ Certificate, X509Certificate }
import java.security.{ KeyStore, PrivateKey, SecureRandom }
import java.util.concurrent._
import javax.net.ssl.{ KeyManagerFactory, SSLContext, TrustManagerFactory }
import scala.concurrent.duration.{ Duration, DurationInt, DurationLong, FiniteDuration }
import scala.concurrent.{ Await, ExecutionContext, Future }
import scala.util.{ Failure, Success }

/**
 * Serves the Dapr app API (`DaprAppPowerApiHandler`) over HTTP or HTTPS.
 *
 * When roles or bindings are configured, it also registers this app through
 * `DaprClient.establishService` and keeps the registration alive with a keep-alive stream.
 * It re-establishes the registration with a growing delay when that stream ends or fails.
 * During coordinated shutdown it deregisters the app (`abolishService`).
 *
 * @param system actor system hosting the HTTP server, the gRPC client and the scheduler
 * @param ec     execution context for future callbacks; defaults to `system.dispatcher`
 * @author ericxin.
 */
class DaprServer(system: ActorSystem)(implicit ec: ExecutionContext = system.dispatcher) {

  private val log: LoggingAdapter = Logging(system, getClass)
  private val app: DaprApp = new DaprApp(system)(ec)

  // Materializer that runs the keep-alive response stream.
  private val mat: Materializer = Materializer(system).withNamePrefix("DaprServer")

  private val daprAppSettings: DaprAppSettings = new DaprAppSettings(system.settings.config)

  // Read with millisecond precision; converting via TimeUnit.SECONDS truncated values such as
  // "500ms" to a zero idle timeout.
  private val serverIdleTimeout: Duration =
    system.settings.config.getDuration("io.openkunlun.server.idle-timeout").toMillis.millis

  private val clientSettings: GrpcClientSettings = GrpcClientSettings.fromConfig("dapr.grpc")(system)

  // Swapped by restartDaprApp(); read from future callbacks and scheduler tasks.
  @volatile private var client: DaprClient = getClient

  // NOTE(review): presumably required implicitly by DaprAppPowerApiHandler(app) — confirm.
  implicit val sys: ActorSystem = system

  @volatile private var https: Boolean = false

  /** Whether the most recent `start` call bound the server with an HTTPS context. */
  def isHttps: Boolean = https

  // Delay before the next re-establish attempt; reset to zero after a successful establish.
  @volatile private var connectionDelay: FiniteDuration = Duration.Zero

  // Pending re-establish task scheduled by restartDaprApp().
  @volatile private var watcher: Option[Cancellable] = None

  // Set once deregistration begins so that the keep-alive stream ending (because the client is
  // closed) does not schedule a new registration during shutdown.
  @volatile private var stopping: Boolean = false

  /**
   * Registers a handler with the underlying [[DaprApp]].
   *
   * @param handler handler to add
   */
  def addHandler(handler: DaprHandler): Unit = {
    app.addHandler(handler)
  }

  /**
   * Registers several handlers with the underlying [[DaprApp]].
   *
   * @param handlers handlers to add
   */
  def addHandlers(handlers: Seq[DaprHandler]): Unit = {
    app.addHandlers(handlers)
  }

  /**
   * Registers a binding handler with the underlying [[DaprApp]].
   *
   * @param handler binding handler to add
   */
  def addBindingHandler(handler: BindingHandler): Unit = {
    app.addBindingHandler(handler)
  }

  /**
   * Registers several binding handlers with the underlying [[DaprApp]].
   *
   * @param handlers binding handlers to add
   */
  def addBindingHandlers(handlers: Seq[BindingHandler]): Unit = {
    app.addBindingHandlers(handlers)
  }

  /**
   * Registers an invocation handler with the underlying [[DaprApp]].
   *
   * @param handler invocation handler to add
   */
  def addInvocationHandler(handler: InvocationHandler): Unit = {
    app.addInvocationHandler(handler)
  }

  /**
   * Registers several invocation handlers with the underlying [[DaprApp]].
   *
   * @param handlers invocation handlers to add
   */
  def addInvocationHandlers(handlers: Seq[InvocationHandler]): Unit = {
    app.addInvocationHandlers(handlers)
  }

  /**
   * Starts a plain HTTP server on the host and port from [[DaprAppSettings]].
   *
   * @return the server binding
   */
  def start(): Future[Http.ServerBinding] = {
    val (host, port) = getHostAndPort()
    start(host, port)
  }

  /** Starts a plain HTTP server on `host:port`. */
  def start(host: String, port: Int): Future[Http.ServerBinding] = {
    run(host, port, None)
  }

  /** Starts an HTTPS server on the configured host/port using a password-less PKCS#12 keystore. */
  def start(pkcs12: InputStream): Future[Http.ServerBinding] = {
    val (host, port) = getHostAndPort()
    start(host, port, pkcs12)
  }

  /** Starts an HTTPS server on `host:port` using a password-less PKCS#12 keystore. */
  def start(host: String, port: Int, pkcs12: InputStream): Future[Http.ServerBinding] = {
    val httpCtx = buildHttpsConnectionContext(pkcs12, None)
    run(host, port, Some(httpCtx))
  }

  /** Starts an HTTPS server on the configured host/port using a password-protected PKCS#12 keystore. */
  def start(pkcs12: InputStream, password: String): Future[Http.ServerBinding] = {
    val (host, port) = getHostAndPort()
    start(host, port, pkcs12, password)
  }

  /** Starts an HTTPS server on `host:port` using a password-protected PKCS#12 keystore. */
  def start(host: String, port: Int, pkcs12: InputStream, password: String): Future[Http.ServerBinding] = {
    val httpCtx = buildHttpsConnectionContext(pkcs12, Some(password))
    run(host, port, Some(httpCtx))
  }

  /** Starts an HTTPS server on the configured host/port from a PEM private key and its certificate. */
  def start(privateKey: String, keyCertChain: X509Certificate): Future[Http.ServerBinding] = {
    val (host, port) = getHostAndPort()
    start(host, port, privateKey, keyCertChain)
  }

  /** Starts an HTTPS server on `host:port` from a PEM private key and its certificate. */
  def start(host: String, port: Int, privateKey: String, keyCertChain: X509Certificate): Future[Http.ServerBinding] = {
    val httpCtx = buildHttpsConnectionContext(privateKey, keyCertChain, None)
    run(host, port, Some(httpCtx))
  }

  /**
   * Starts an HTTPS server on the configured host/port from a PEM private key.
   * `keyPassword` protects the in-memory keystore entry.
   */
  def start(privateKey: String, keyPassword: String, keyCertChain: X509Certificate): Future[Http.ServerBinding] = {
    val (host, port) = getHostAndPort()
    start(host, port, privateKey, keyPassword, keyCertChain)
  }

  /**
   * Starts an HTTPS server on `host:port` from a PEM private key.
   * `keyPassword` protects the in-memory keystore entry.
   */
  def start(host: String, port: Int, privateKey: String, keyPassword: String, keyCertChain: X509Certificate): Future[Http.ServerBinding] = {
    val httpCtx = buildHttpsConnectionContext(privateKey, keyCertChain, Some(keyPassword))
    run(host, port, Some(httpCtx))
  }

  private def getHostAndPort(): (String, Int) = {
    (daprAppSettings.host, daprAppSettings.port)
  }

  /**
   * Binds the app handler on `host:port`, optionally with TLS, then starts the app registration.
   * If either step fails, the actor system is terminated.
   *
   * @return the binding future (it does not wait for the registration to complete)
   */
  private def run(host: String, port: Int, httpCtx: Option[HttpsConnectionContext]): Future[Http.ServerBinding] = {
    https = httpCtx.isDefined
    val scheme = if (https) "https" else "http"

    val plainBuilder = Http(system)
      .newServerAt(interface = host, port = port)
      .adaptSettings(settings => settings.withTimeouts(settings.timeouts.withIdleTimeout(serverIdleTimeout)))
      .logTo(log)
    val serverBuilder = httpCtx.fold(plainBuilder)(ctx => plainBuilder.enableHttps(ctx))

    val serverBinding: Future[Http.ServerBinding] =
      serverBuilder
        .bind(DaprAppPowerApiHandler(app))
        .map(_.addToCoordinatedShutdown(hardTerminationDeadline = 10.seconds))

    serverBinding.flatMap(binding => runDaprApp.map(_ => binding)).onComplete {
      case Success(binding) =>
        val address = binding.localAddress
        log.warning("Dapr server started. listen on {}://{}:{}", scheme, address.getHostString, address.getPort)
      case Failure(ex) =>
        // Pass the cause explicitly: warning(template, ex) treated `ex` as an unused template argument.
        // Do not block here: this callback runs on the dispatcher of the system being terminated.
        log.error(ex, "Failed to bind endpoint, terminating system")
        system.terminate()
    }
    serverBinding
  }

  // True when the app declares roles or bindings and therefore registers itself via DaprClient.
  private def hasRegistration: Boolean =
    daprAppSettings.roles.nonEmpty || daprAppSettings.bindings.nonEmpty

  /**
   * Establishes the app registration (when required) and schedules deregistration as a
   * coordinated-shutdown task. The task runs before the HTTP binding is unbound and while the
   * gRPC client is still usable. `registerOnTermination` only ran after the system had terminated.
   */
  private def runDaprApp: Future[Unit] =
    if (!hasRegistration) Future.unit
    else startDaprApp().map { _ =>
      CoordinatedShutdown(system).addTask(CoordinatedShutdown.PhaseBeforeServiceUnbind, "dapr-app-abolish") { () =>
        stopDaprApp().map(_ => Done)
      }
    }

  private def getClient: DaprClient = {
    DaprClient(clientSettings.withDeadline(Duration.Inf).withTls(clientSettings.trustManager.nonEmpty))(system)
  }

  /**
   * Sends an `EstablishRequest` for this app. On success it resets the retry delay and starts the
   * keep-alive stream. On failure it logs the error and schedules a restart. The returned
   * future never fails.
   */
  private def startDaprApp(): Future[_] =
    if (stopping || !hasRegistration) Future.unit
    else {
      val request = EstablishRequest(
        daprAppSettings.id,
        daprAppSettings.host,
        daprAppSettings.port,
        if ("http".equalsIgnoreCase(daprAppSettings.protocol)) EstablishRequest.Protocol.HTTP else EstablishRequest.Protocol.GRPC,
        daprAppSettings.roles,
        daprAppSettings.bindings,
        daprAppSettings.weight,
        daprAppSettings.warmup,
        System.currentTimeMillis(),
        daprAppSettings.ttl
      )
      client
        .establishService(request)
        .map { response =>
          connectionDelay = Duration.Zero
          keepAliveDaprApp()
          log.warning("Dapr app started. [id: {}, ttl: {}, roles: {}, bindings: {}]", response.id, response.ttl, daprAppSettings.roles.mkString("[", ",", "]"), daprAppSettings.bindings.mkString("[", ",", "]"))
        }
        .recover {
          case e =>
            log.error("Start dapr app failed. cause: [{}].", Option(e.getCause).map(_.toString).getOrElse(e.getMessage))
            restartDaprApp()
        }
    }

  /**
   * Deregisters the app and releases the client.
   *
   * Marks the server as stopping and cancels any pending re-establish. It then sends
   * `AbolishRequest` and always closes the client, even if the abolish call fails.
   * The returned future carries the abolish outcome.
   */
  private def stopDaprApp(): Future[Unit] = {
    stopping = true
    watcher.foreach(_.cancel())
    val current = client
    current
      .abolishService(AbolishRequest(daprAppSettings.id))
      .transformWith { abolished =>
        current.close().flatMap(_ => Future.fromTry(abolished.map(_ => ())))
      }
  }

  /**
   * Replaces the client and schedules `startDaprApp` after `connectionDelay`.
   *
   * The delay grows by `connectionIdleInterval` on each attempt. Once the next value would exceed
   * `connectionMaxTimeout`, it falls back to `connectionIdleInterval`. Does nothing once
   * shutdown has started.
   */
  private def restartDaprApp(): Unit =
    if (stopping) log.info("Dapr server is stopping, skip restarting dapr app.")
    else {
      log.warning("Dapr app restarting......")
      client.close()
      client = getClient
      watcher.foreach(_.cancel())
      watcher = Some(system.scheduler.scheduleOnce(connectionDelay)(startDaprApp()))
      val next = connectionDelay.plus(daprAppSettings.connectionIdleInterval)
      connectionDelay = if (next > daprAppSettings.connectionMaxTimeout) daprAppSettings.connectionIdleInterval else next
    }

  /**
   * Streams a `KeepAliveRequest` every `keepAliveInterval` to `client.keepAliveService`.
   * When the response stream completes or fails, the registration is restarted.
   */
  private def keepAliveDaprApp(): Unit = {
    val requestStream: Source[KeepAliveRequest, NotUsed] = Source.tick(daprAppSettings.keepAliveInterval, daprAppSettings.keepAliveInterval, KeepAliveRequest(daprAppSettings.id))
      .mapMaterializedValue(_ => NotUsed)

    val responseStream: Source[KeepAliveResponse, NotUsed] = client.keepAliveService(requestStream)
    responseStream.runForeach(it => log.info("Keep alive dapr app. [id: {}, ttl: {}]", it.id, it.ttl))(mat) onComplete {
      case Success(_) =>
        log.warning("Keep alive interrupted. restart dapr app.")
        restartDaprApp()
      case Failure(e) =>
        log.error("Keep alive failed. cause: [{}]. restart dapr app.", Option(e.getCause).map(_.toString).getOrElse(e.getMessage))
        restartDaprApp()
    }
  }

  private val PKCS12: String = "PKCS12"
  private val TLS: String = "TLS"
  private val SunX509: String = "SunX509"
  private val emptyPassword: Array[Char] = new Array[Char](0)

  /**
   * Builds a server TLS context from a PEM-encoded private key and its certificate.
   *
   * The key is placed in an in-memory PKCS#12 keystore under `keyPassword` (empty if absent).
   * No trust managers are configured.
   */
  private def buildHttpsConnectionContext(privateKey: String, keyCertChain: X509Certificate, keyPassword: Option[String]): HttpsConnectionContext = {
    val pk: PrivateKey = DERPrivateKeyLoader.load(PEMDecoder.decode(privateKey))
    val entryPassword = keyPassword.map(_.toCharArray).getOrElse(emptyPassword)
    val ks = KeyStore.getInstance(PKCS12)
    ks.load(null)
    ks.setKeyEntry(
      "private",
      pk,
      entryPassword,
      Array[Certificate](keyCertChain)
    )
    val keyManagerFactory = KeyManagerFactory.getInstance(SunX509)
    // Must unlock the entry with the password it was stored under; initialising with null made
    // any non-empty keyPassword unrecoverable.
    keyManagerFactory.init(ks, entryPassword)
    val context = SSLContext.getInstance(TLS)
    context.init(keyManagerFactory.getKeyManagers, null, new SecureRandom)
    ConnectionContext.httpsServer(context)
  }

  /**
   * Builds a server TLS context from a PKCS#12 keystore stream.
   *
   * The same password (empty if absent) opens the keystore and its keys. The keystore is also
   * used as the trust store. The stream is not closed here.
   */
  private def buildHttpsConnectionContext(pkcs12: InputStream, password: Option[String]): HttpsConnectionContext = {
    val passwordCharArray = password.map(_.toCharArray).getOrElse(emptyPassword)
    val ks: KeyStore = KeyStore.getInstance(PKCS12)
    ks.load(pkcs12, passwordCharArray)
    val keyManagerFactory: KeyManagerFactory = KeyManagerFactory.getInstance(SunX509)
    keyManagerFactory.init(ks, passwordCharArray)

    val tmf: TrustManagerFactory = TrustManagerFactory.getInstance(SunX509)
    tmf.init(ks)

    val sslContext: SSLContext = SSLContext.getInstance(TLS)
    sslContext.init(keyManagerFactory.getKeyManagers, tmf.getTrustManagers, new SecureRandom)
    ConnectionContext.httpsServer(sslContext)
  }
}